In [1]:
import torch
import torch.nn as nn
# from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import torchvision.models as models
import torch.backends.cudnn as cudnn
import torchvision
import torch.autograd as autograd
from PIL import Image
import imp
import os
import sys
import math
import time
import random
import shutil
# import cv2
import scipy.misc
from glob import glob
import sklearn
import logging

from time import time
from tqdm import tqdm
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')

%matplotlib inline
In [2]:
# The rest of the notebook assumes a GPU (cuda=True below); warn up front.
if not torch.cuda.is_available():
    print("SORRY: No CUDA device")
In [3]:
imageSize = 64  # images are resized/center-cropped to 64x64 (matches net architectures)
batchSize = 64  # minibatch size for both G and D updates
In [4]:
nz = 100   # dimensionality of the latent z vector
ngf = 64   # base number of generator feature maps
ndf = 64   # base number of discriminator feature maps
nc = 3     # number of image channels (RGB)

nd = 3     # number of discriminators trained side by side
cuda = True  # run on GPU (the cell above warns if no CUDA device is present)

Load data

In [5]:
# Root folder for CelebA; ImageFolder expects at least one class subfolder
# beneath this path (e.g. celeba/img/...).
PATH = 'celeba/'

data = dset.ImageFolder(PATH,
    transforms.Compose([
        # NOTE(review): transforms.Scale was renamed transforms.Resize in
        # torchvision 0.2 and removed later — update when upgrading torchvision.
        transforms.Scale(imageSize),
        transforms.CenterCrop(imageSize),
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1], matching the generator's Tanh output.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
)

dataloader = DataLoader(data, batch_size=batchSize, shuffle=True)

Custom weights initialization, applied below to netG and to each netD

In [6]:
def weights_init(m):
    """DCGAN weight initialization, meant for use with Module.apply().

    Conv/ConvTranspose weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02)
    with biases zeroed. Other layer types are left untouched.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
In [26]:
class _netG(nn.Module):
    """DCGAN generator: latent batch (nz x 1 x 1) -> (nc) x 64 x 64 image batch.

    Each strided transposed convolution doubles the spatial size; the final
    Tanh maps outputs into [-1, 1] (matching the Normalize transform applied
    to the training images).
    """

    def __init__(self):
        super(_netG, self).__init__()
        # Number of GPUs to spread the forward pass over (1 = no data_parallel).
        self.ngpu = 1
        layers = [
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a latent batch to an image batch; optionally split across GPUs."""
        on_gpu = isinstance(input.data, torch.cuda.FloatTensor)
        if on_gpu and self.ngpu > 1:
            return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return self.main(input)
    
# Build the generator and apply the DCGAN weight initialization.
# Module.apply() recurses over submodules in place and returns the module,
# so it can be chained; the bare trailing expression displays the summary.
netG = _netG().apply(weights_init)
netG
Out[26]:
_netG (
  (main): Sequential (
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU (inplace)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU (inplace)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU (inplace)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU (inplace)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh ()
  )
)
In [27]:
class _netD(nn.Module):
    """DCGAN discriminator: (nc) x 64 x 64 image -> probability of 'real'.

    Four stride-2 convolutions shrink the image to (ndf*8) x 4 x 4; a final
    4x4 valid convolution reduces it to a single value per sample, squashed
    to [0, 1] by Sigmoid (suitable for BCELoss).
    """

    def __init__(self):
        super(_netD, self).__init__()
        # Number of GPUs to spread the forward pass over (1 = no data_parallel).
        self.ngpu = 1
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Score a batch of images; returns an (N, 1, 1, 1) probability tensor."""
        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        # BUG FIX: the original had a second, unreachable
        # `return output.view(-1, 1)` after this return; removed as dead code.
        return output

# Build and initialize the pool of nd discriminators. Module.apply() runs
# in place and returns its module, so the comprehension yields the
# initialized nets directly; the bare expression displays their summaries.
netDs = [_netD().apply(weights_init) for _ in range(nd)]
netDs
Out[27]:
[_netD (
   (main): Sequential (
     (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (1): LeakyReLU (0.2, inplace)
     (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
     (4): LeakyReLU (0.2, inplace)
     (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
     (7): LeakyReLU (0.2, inplace)
     (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
     (10): LeakyReLU (0.2, inplace)
     (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
     (12): Sigmoid ()
   )
 ), _netD (
   (main): Sequential (
     (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (1): LeakyReLU (0.2, inplace)
     (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
     (4): LeakyReLU (0.2, inplace)
     (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
     (7): LeakyReLU (0.2, inplace)
     (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
     (10): LeakyReLU (0.2, inplace)
     (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
     (12): Sigmoid ()
   )
 ), _netD (
   (main): Sequential (
     (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (1): LeakyReLU (0.2, inplace)
     (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
     (4): LeakyReLU (0.2, inplace)
     (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
     (7): LeakyReLU (0.2, inplace)
     (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
     (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
     (10): LeakyReLU (0.2, inplace)
     (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
     (12): Sigmoid ()
   )
 )]

Setup input and output tensors

In [28]:
# Binary cross-entropy for the real-vs-fake classification game.
criterion = nn.BCELoss()

# Pre-allocated buffers that are resized/refilled each iteration
# (an old-PyTorch pattern; modern code would build tensors per batch).
input = torch.FloatTensor(batchSize, 3, imageSize, imageSize)
noise = torch.FloatTensor(batchSize, nz, 1, 1)
# A fixed latent batch, reused for sampling so images stay comparable over training.
fixed_noise = torch.FloatTensor(batchSize, nz, 1, 1).normal_(0, 1)
label = torch.FloatTensor(batchSize)
real_label = 1
fake_label = 0

# Move everything to the GPU *before* wrapping in Variable, so the Variables
# reference CUDA storage.
if cuda:
    [netD.cuda() for netD in netDs]
    netG.cuda()
    criterion.cuda()
    input, label = input.cuda(), label.cuda()
    noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

# NOTE(review): Variable is deprecated since PyTorch 0.4 (tensors are
# autograd-aware directly); kept for compatibility with this old API.
input = Variable(input)
label = Variable(label)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)

Setup optimizer

In [29]:
lr = 0.0002  # Adam learning rate (DCGAN paper value)
beta1 = 0.5  # Adam beta1 (DCGAN paper uses 0.5 instead of the default 0.9)
In [30]:
# One Adam optimizer per discriminator, plus one for the generator,
# all sharing the DCGAN hyperparameters defined above.
optimizerDs = [optim.Adam(d.parameters(), lr=lr, betas=(beta1, 0.999)) for d in netDs]
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Train

In [31]:
niter = 1  # number of training epochs
In [32]:
# Generate one fake batch so the timing cells below have data to score.
fake = netG(noise)
In [33]:
%timeit 
[netD(fake).sum().data.cpu().numpy()[0] for netD in netDs]
Out[33]:
[30.592772, 53.0947, 47.219414]
In [34]:
%%time
# Wall-clock a single pass of every discriminator over the fake batch.
[netD(fake).sum() for netD in netDs]
CPU times: user 46.9 ms, sys: 12.4 ms, total: 59.4 ms
Wall time: 58.3 ms
Out[34]:
[Variable containing:
  30.5928
 [torch.cuda.FloatTensor of size 1 (GPU 0)], Variable containing:
  53.0947
 [torch.cuda.FloatTensor of size 1 (GPU 0)], Variable containing:
  47.2194
 [torch.cuda.FloatTensor of size 1 (GPU 0)]]
In [104]:
# Scratch: expand torch's old scientific-notation printout into plain floats.
1.00000e-02 *7.7859,  1.00000e-03 *3.3553
Out[104]:
(0.077859, 0.0033553000000000003)
In [ ]:
 
In [35]:
tick = time()
losses = []  # (errD, errG) per iteration; errD is from the *last* discriminator only
for epoch in range(niter):
    for i, data in enumerate(dataloader):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################

        # Each discriminator takes its own real/fake step. NOTE(review): the
        # real batch is re-copied into `input` once per D (redundant but
        # harmless), and each D sees a *different* fake batch because the
        # noise is resampled inside the loop.
        for netD, optimizerD in zip(netDs, optimizerDs):
            netD.zero_grad()
            real_cpu, _ = data
            batch_size = real_cpu.size(0)
            # Resize-and-copy into the pre-allocated Variables (old-torch idiom).
            input.data.resize_(real_cpu.size()).copy_(real_cpu)
            label.data.resize_(batch_size).fill_(real_label)

            output = netD(input)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.data.mean()

            # train with fake
            noise.data.resize_(batch_size, nz, 1, 1)
            noise.data.normal_(0, 1)
            fake = netG(noise)
            label.data.fill_(fake_label)
            # detach() keeps D's gradients from flowing back into the generator.
            output = netD(fake.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.data.mean()
            errD = errD_real + errD_fake

            optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################

        netG.zero_grad()

        # Select the discriminator that scores the current fakes HIGHEST
        # (i.e. the one most fooled) — this is an argmax, not a random pick
        # as the original comment claimed. `fake` here is the batch that was
        # generated for the last D in the loop above.
        whichD = np.argmax([netD(fake).sum().data.cpu().numpy()[0] for netD in netDs])
        netD = netDs[whichD]

        output = netD(fake)
        label.data.fill_(real_label)  # fake labels are real for generator cost

        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.data.mean()
        optimizerG.step()

        # D_x / D_G_z1 / errD reflect only the last discriminator's step.
        losses.append((errD.data[0], errG.data[0]))

        if i%100 == 0:
            print('[{}/{}][{}/{}][{:.2f}] Loss_D: {:.2f} Loss_G: {:.2f} D(x): {:.2f} D(G(z)): {:.2f} / {:.2f}'.\
            format(epoch, niter, i, len(dataloader), time() - tick, errD.data[0], errG.data[0], D_x, D_G_z1, D_G_z2))

        if i%100 == 0:
            # NOTE(review): show() is defined in a *later* cell (In [19]);
            # on a fresh kernel that cell must run before this one.
            fake = netG(fixed_noise).data
            show(vutils.make_grid(fake[:25], normalize=True, nrow=5).cpu())
            #show(vutils.make_grid(fake, normalize=True).cpu())
[0/1][0/3166][0.69] Loss_D: 1.53 Loss_G: 3.45 D(x): 0.48 D(G(z)): 0.41 / 0.05
[0/1][100/3166][62.16] Loss_D: 0.14 Loss_G: 21.93 D(x): 0.93 D(G(z)): 0.00 / 0.00
[0/1][200/3166][123.80] Loss_D: 2.35 Loss_G: 2.70 D(x): 0.29 D(G(z)): 0.01 / 0.11
[0/1][300/3166][185.78] Loss_D: 0.18 Loss_G: 4.52 D(x): 0.88 D(G(z)): 0.02 / 0.04
[0/1][400/3166][247.87] Loss_D: 0.03 Loss_G: 2.77 D(x): 0.99 D(G(z)): 0.02 / 0.13
[0/1][500/3166][309.97] Loss_D: 0.25 Loss_G: 3.68 D(x): 0.85 D(G(z)): 0.02 / 0.07
[0/1][600/3166][372.12] Loss_D: 0.35 Loss_G: 6.64 D(x): 0.80 D(G(z)): 0.02 / 0.01
[0/1][700/3166][434.34] Loss_D: 0.16 Loss_G: 4.32 D(x): 0.89 D(G(z)): 0.02 / 0.04
[0/1][800/3166][496.57] Loss_D: 0.15 Loss_G: 9.23 D(x): 0.97 D(G(z)): 0.09 / 0.02
[0/1][900/3166][558.83] Loss_D: 0.57 Loss_G: 6.75 D(x): 0.95 D(G(z)): 0.31 / 0.00
[0/1][1000/3166][621.10] Loss_D: 0.05 Loss_G: 5.34 D(x): 0.98 D(G(z)): 0.03 / 0.01
[0/1][1100/3166][683.31] Loss_D: 0.41 Loss_G: 2.94 D(x): 0.78 D(G(z)): 0.07 / 0.09
[0/1][1200/3166][745.61] Loss_D: 0.14 Loss_G: 4.68 D(x): 0.93 D(G(z)): 0.05 / 0.04
[0/1][1300/3166][807.87] Loss_D: 0.07 Loss_G: 5.79 D(x): 0.97 D(G(z)): 0.04 / 0.01
[0/1][1400/3166][870.21] Loss_D: 3.12 Loss_G: 1.62 D(x): 0.15 D(G(z)): 0.01 / 0.30
[0/1][1500/3166][932.56] Loss_D: 0.36 Loss_G: 4.09 D(x): 0.93 D(G(z)): 0.22 / 0.03
[0/1][1600/3166][994.81] Loss_D: 0.39 Loss_G: 3.15 D(x): 0.92 D(G(z)): 0.19 / 0.15
[0/1][1700/3166][1057.09] Loss_D: 0.38 Loss_G: 6.21 D(x): 0.79 D(G(z)): 0.02 / 0.01
[0/1][1800/3166][1119.39] Loss_D: 0.16 Loss_G: 4.55 D(x): 0.91 D(G(z)): 0.01 / 0.03
[0/1][1900/3166][1181.65] Loss_D: 0.48 Loss_G: 3.21 D(x): 0.89 D(G(z)): 0.21 / 0.10
[0/1][2000/3166][1243.94] Loss_D: 0.27 Loss_G: 3.17 D(x): 0.88 D(G(z)): 0.08 / 0.07
[0/1][2100/3166][1306.21] Loss_D: 0.28 Loss_G: 3.53 D(x): 0.92 D(G(z)): 0.13 / 0.08
[0/1][2200/3166][1368.51] Loss_D: 0.50 Loss_G: 5.70 D(x): 0.87 D(G(z)): 0.23 / 0.03
[0/1][2300/3166][1430.77] Loss_D: 0.13 Loss_G: 3.53 D(x): 0.92 D(G(z)): 0.03 / 0.06
[0/1][2400/3166][1493.02] Loss_D: 0.17 Loss_G: 4.38 D(x): 0.90 D(G(z)): 0.03 / 0.03
[0/1][2500/3166][1555.38] Loss_D: 0.09 Loss_G: 4.54 D(x): 0.94 D(G(z)): 0.02 / 0.04
[0/1][2600/3166][1617.64] Loss_D: 0.45 Loss_G: 2.60 D(x): 0.72 D(G(z)): 0.03 / 0.13
[0/1][2700/3166][1679.91] Loss_D: 0.20 Loss_G: 5.06 D(x): 0.95 D(G(z)): 0.11 / 0.04
[0/1][2800/3166][1742.27] Loss_D: 0.24 Loss_G: 3.93 D(x): 0.85 D(G(z)): 0.01 / 0.06
[0/1][2900/3166][1804.56] Loss_D: 0.78 Loss_G: 1.72 D(x): 0.66 D(G(z)): 0.15 / 0.25
[0/1][3000/3166][1866.81] Loss_D: 0.56 Loss_G: 2.18 D(x): 0.70 D(G(z)): 0.04 / 0.20
[0/1][3100/3166][1929.14] Loss_D: 0.45 Loss_G: 2.96 D(x): 0.93 D(G(z)): 0.23 / 0.11

Show

In [19]:
def show(img, fs=(6,6)):
    """Display a CHW image tensor with matplotlib.

    Moves channels last (HWC) as imshow expects; `fs` is the figure size.
    """
    plt.figure(figsize=fs)
    hwc = img.numpy().transpose((1, 2, 0))
    plt.imshow(hwc)
    plt.show()
In [36]:
# Generate samples from the fixed noise batch for a final visual check.
fake = netG(fixed_noise).data
In [37]:
# 5x5 grid of the first 25 samples, normalized into displayable range.
show(vutils.make_grid(fake[:25], normalize=True, nrow=5).cpu())
In [ ]: